#!/usr/bin/python

import os, pprint, random, string, sys, tempfile

sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, "common"))
import driver, rdb_unittest, utils

# `get_char` turns a code point into a one-character text string.
# Start from `chr` and switch to `unichr` on interpreters that provide it.
get_char = chr
try:
    get_char = unichr
except NameError:
    # no `unichr` builtin here, so `chr` stays in place
    pass

def counter(i=0):
    '''Endless generator of consecutive numbers; the first value produced is `i + 1`.'''
    current = i + 1
    while True:
        yield current
        current += 1

# cache for the unicode alphabet, filled in on the first `use_unicode=True` call
__unicode_alphabet = None
def randomString(length=None, max_length=100, use_unicode=False):
    '''Return a string of randomly chosen characters.
    
    Args:
        length: exact number of characters; when None a random length between
            1 and `max_length` (inclusive) is picked
        max_length: upper bound used only when `length` is None
        use_unicode: draw from a set of unicode code-point ranges instead of
            ASCII uppercase letters and digits
    '''
    global __unicode_alphabet
    
    if length is None:
        length = random.randint(1, max_length)
    
    if not use_unicode:
        alphabet = string.ascii_uppercase + string.digits
    else:
        if __unicode_alphabet is None:
            # this idea is from http://stackoverflow.com/questions/1477294/generate-random-utf-8-string-in-python
            code_point_ranges = (
                ( 0x0021, 0x0021 ), ( 0x0023, 0x0026 ), ( 0x0028, 0x007E ), ( 0x00A1, 0x00AC ),
                ( 0x00AE, 0x00FF ), ( 0x0100, 0x017F ), ( 0x0180, 0x024F ), ( 0x2C60, 0x2C7F ),
                ( 0x16A0, 0x16F0 ), ( 0x0370, 0x0377 ), ( 0x037A, 0x037E ), ( 0x0384, 0x038A ),
                ( 0x038C, 0x038C )
            )
            built = []
            for lower, upper in code_point_ranges:
                # both ends of each range are inclusive
                for code_point in range(lower, upper + 1):
                    built.append(get_char(code_point))
            __unicode_alphabet = built
        alphabet = __unicode_alphabet
    
    picked = [random.choice(alphabet) for _ in range(length)]
    return ''.join(picked)

class BackupTesting(rdb_unittest.RdbTestCase):
    '''Common settings and a server-comparison helper shared by the backup tool tests.'''
    
    # Fixture settings. `batchsize` is read by BackupRoundtrip.populateTable; the others
    # are presumably consumed by rdb_unittest.RdbTestCase - confirm there.
    servers           = 1
    tables            = 3
    recordsToGenerate = 200
    batchsize         = 40
    cleanTables       = False
    
    def compare_servers(self, targetConn, tables=None):
        '''Assert that tables on the source server (`self.conn`) are mirrored on the target.
        
        Every table listed in the source server's `rethinkdb.table_config` is checked on
        the target for existence, secondary index status, primary keys and full documents.
        
        Args:
            targetConn: connection to the server that is compared against `self.conn`
            tables: optional table name, or collection of table names, to restrict the
                comparison to (matched on table name only, not on database); when None
                every table found on the source server is compared
        
        Raises:
            AssertionError: when an index, the primary keys or a document differ
            Exception: when a table is missing on the target, or one side holds more rows
        '''
        self.maxDiff = None
        
        # `unicode` only exists on Python 2, referencing it directly is a NameError on Python 3
        try:
            stringTypes = (str, unicode)
        except NameError:
            stringTypes = (str,)
        if isinstance(tables, stringTypes):
            tables = [tables]
        
        tableConfigs = self.r.db('rethinkdb').table('table_config').pluck('db', 'name', 'primary_key').map(self.r.row.values()).run(self.conn)
        for db, table, primary_key in tableConfigs:
            if tables is not None and table not in tables:
                continue
            
            # - confirm table existence and metadata
            try:
                self.r.db(db).table(table).info().run(targetConn)
            except self.r.ReqlOpFailedError as e:
                raise Exception('Table %s.%s does not exist on the target server: %s' % (db, table, str(e)))
            
            # - compare the indexes
            actualIndexes = dict((x['index'], x) for x in self.r.db(db).table(table).index_status().run(targetConn))
            for expected in self.r.db(db).table(table).index_status().run(self.conn):
                indexName = expected['index']
                self.assertTrue(indexName in actualIndexes, 'Did not find the index < %s > in %s' % (indexName, actualIndexes))
                actual = actualIndexes[indexName]
                for value in ('function', 'geo', 'multi', 'outdated', 'query', 'ready'):
                    if actual[value] != expected[value]:
                        # `fail` is the unittest.TestCase method for reporting an explicit failure
                        self.fail('On table %s.%s index %s the %s did not match, got < %s > when expected %s' % (db, table, indexName, value, str(actual[value]), str(expected[value])))
            
            # - compare the keys
            source = list(self.r.db(db).table(table).order_by(index=primary_key)[primary_key].run(self.conn))
            target = list(self.r.db(db).table(table).order_by(index=primary_key)[primary_key].run(targetConn))
            sys.stdout.flush()
            self.assertEqual(source, target)
            
            # - compare the data
            source = self.r.db(db).table(table).order_by(index=primary_key).run(self.conn)
            target = self.r.db(db).table(table).order_by(index=primary_key).run(targetConn)
            
            # start at zero so the messages below still work when a table is empty
            i = 0
            for i, (expected, actual) in enumerate(zip(source, target), 1):
                self.assertEqual(actual, expected, 'Expected did not match actual in table %s.%s item %d. Expected:\n%s\nvs. Actual:\n%s' % (db, table, i, pprint.pformat(expected), pprint.pformat(actual)))
            
            # ensure the cursors have been drained on both
            # NOTE: zip reads from `source` first, so it may already have consumed one
            # unmatched source row; the key comparison above covers that case
            try:
                source.next()
            except self.r.ReqlCursorEmpty:
                pass
            else:
                raise Exception('In %s.%s there were more entries in the source than target (%d)' % (db, table, i))
            try:
                target.next()
            except self.r.ReqlCursorEmpty:
                pass
            else:
                raise Exception('In %s.%s there were more entries in the target than source (%d)' % (db, table, i))

class BackupRoundtrip(BackupTesting):
    '''Round-trip tests for the dump, restore, export, import and index-rebuild tools.'''
    
    def _randomRecord(self):
        '''Build one document of random data; some values are ReQL terms evaluated by the server.'''
        return {
            'id':       self.r.uuid(),
            'bool':     random.choice((False, True)),
            'integer':  random.randint(-500, 500),
            'float':    random.uniform(-100000, 100000),
            'string':   randomString(use_unicode=False),
            'unicode':  randomString(use_unicode=True),
            'dict': {
                'a':    randomString(length=10, use_unicode=False),
                'b':    randomString(length=10, use_unicode=True)
            },
            'array': [
                randomString(length=10, use_unicode=False),
                randomString(length=10, use_unicode=True)
            ],
            'complicated': {
                'a': [randomString(length=10, use_unicode=False), random.randint(-500, 500), random.choice((False, True))],
                '2': [
                    {
                        'a': randomString(length=10, use_unicode=False),
                        'b': random.randint(-500, 500),
                        'c': random.choice((False, True))
                    }, {
                        'a': randomString(length=10, use_unicode=False),
                        'b': random.randint(-500, 500),
                        'c': random.choice((False, True))
                    }
                ]
            },
            'binary':    self.r.binary(bytes(bytearray(random.getrandbits(8) for _ in range(10)))),
            'time':      self.r.time(
                            random.randint(1400, 9999),
                            random.randint(1, 12),
                            random.randint(1, 28),
                            random.randint(0, 23),
                            random.randint(0, 59),
                            random.uniform(0, 59),
                            random.choice(('Z', '+01:00', '-05:00', '+10:00', '-11:00'))
                        ),
            # ReQL coordinates are [longitude, latitude]: latitude must stay within -90..90
            'circle':   self.r.circle([random.uniform(-180, 180), random.uniform(-90, 90)], random.uniform(1, 1000))
        }
    
    def populateTable(self, conn=None, table=None, records=None, fieldName=None):
        '''Insert random documents into a table in batches, then create secondary indexes.
        
        Args:
            conn: connection to run the queries on, defaults to `self.conn`
            table: table term to fill, defaults to `self.table`
            records: number of documents to insert, defaults to `self.recordsToGenerate`
            fieldName: accepted for signature compatibility, not used here
        
        Raises:
            AssertionError: when an insert reports errors or the inserted total is not `records`
        '''
        if conn is None:
            conn = self.conn
        if table is None:
            table = self.table
        if records is None:
            records = self.recordsToGenerate
        batchsize = self.batchsize
        
        # - insert random data
        inserted = 0
        for batchStart in range(0, records, batchsize):
            # the final batch is shorter when records is not a multiple of batchsize
            batchEnd = min(records, batchStart + batchsize)
            batch = [self._randomRecord() for _ in range(batchStart, batchEnd)]
            res = table.insert(batch).run(conn)
            self.assertEqual(res['errors'], 0, 'There were errors inserting data into table: %s' % res)
            inserted += res['inserted']
        self.assertEqual(inserted, records)
        
        # - create indexes
        table.index_create('bool').run(conn)
        table.index_create('unicode').run(conn)
        table.index_create('complicated').run(conn)
        table.index_create('function', self.r.row['dict']['a']).run(conn)
        table.index_wait().run(conn)
    
    def test_dump_restore(self):
        '''Dump the source server, restore into a fresh server, compare, then restore again with --force.'''
        sourceServer  = self.cluster[0]
        targetCluster = driver.Cluster(initial_servers=1)
        targetServer  = targetCluster[0]
        targetConn    = self.r.connect(host=targetServer.host, port=str(targetServer.driver_port))
        
        # - dump to file
        # a fresh private folder gives a file path that does not exist yet without the
        # race of tempfile.mktemp, and is registered for cleanup before the dump can fail
        dumpFolder = tempfile.mkdtemp()
        utils.cleanupPathAtExit(dumpFolder)
        dumpFile = os.path.join(dumpFolder, 'dump.tar.gz')
        self.assertEqual(0, self.r._dump.main(argv=[
            '--quiet', '--debug', '--host-name', sourceServer.host, '--driver-port', str(sourceServer.driver_port),
            '-f', dumpFile
        ]))
        
        # - restore to a second server
        self.assertEqual(0, self.r._restore.main(argv=[
            '--quiet', '--debug', '--host-name', targetServer.host, '--driver-port', str(targetServer.driver_port),
            dumpFile
        ]))
        
        # - validate that all tables match
        self.compare_servers(targetConn)
        
        # - restore again with --force
        self.assertEqual(0, self.r._restore.main(argv=[
            '--quiet', '--debug', '--host-name', targetServer.host, '--driver-port', str(targetServer.driver_port),
            dumpFile, '--force'
        ]))
    
    def test_export_import_json(self):
        '''Export each table as JSON, import its file into a fresh server, and compare that table.'''
        sourceServer  = self.cluster[0]
        targetCluster = driver.Cluster(initial_servers=1)
        targetServer  = targetCluster[0]
        targetConn    = self.r.connect(host=targetServer.host, port=str(targetServer.driver_port))
        
        outputFolder  = tempfile.mkdtemp()
        utils.cleanupPathAtExit(outputFolder)
        
        for table in self.tableNames:
            export_folder = os.path.join(outputFolder, 'export_%s' % table)
            
            # - export the table
            self.assertEqual(0, self.r._export.main(argv=[
                '--quiet', '--debug', '--host-name', sourceServer.host, '--driver-port', str(sourceServer.driver_port),
                '-d', export_folder, '-e', '%s.%s' % (self.dbName, table), '--format', 'json'
            ]))
            
            # - verify existence of files
            outputBasePath = os.path.join(export_folder, self.dbName, table)
            self.assertTrue(os.path.exists(outputBasePath + '.info'), 'There was no file at: %s' % (outputBasePath + '.info'))
            self.assertTrue(os.path.exists(outputBasePath + '.json'), 'There was no file at: %s' % (outputBasePath + '.json'))
            
            # - import the table
            self.assertEqual(0, self.r._import.main(argv=[
                '--quiet', '--debug', '--host-name', targetServer.host, '--driver-port', str(targetServer.driver_port),
                '-f', outputBasePath + '.json', '--table', '%s.%s' % (self.dbName, table)
            ]))
            
            # - validate that the table matches
            self.compare_servers(targetConn, tables=[table])
    
    def test_export_import_dir(self):
        '''Export the whole test database as JSON, import the folder into a fresh server, and compare.'''
        sourceServer  = self.cluster[0]
        targetCluster = driver.Cluster(initial_servers=1)
        targetServer  = targetCluster[0]
        targetConn    = self.r.connect(host=targetServer.host, port=str(targetServer.driver_port))
        
        outputFolder  = tempfile.mkdtemp()
        utils.cleanupPathAtExit(outputFolder)
        # export into a sub-folder that does not exist yet
        outputFolder  = os.path.join(outputFolder, 'export')
        
        # - export the test database
        self.assertEqual(0, self.r._export.main(argv=[
            '--quiet', '--debug', '--host-name', sourceServer.host, '--driver-port', str(sourceServer.driver_port),
            '-d', outputFolder, '-e', '%s' % (self.dbName), '--format', 'json'
        ]))
        
        # - verify existence of files
        for table in self.tableNames:
            outputBasePath = os.path.join(outputFolder, self.dbName, table)
            self.assertTrue(os.path.exists(outputBasePath + '.info'), 'There was no file at: %s' % (outputBasePath + '.info'))
            self.assertTrue(os.path.exists(outputBasePath + '.json'), 'There was no file at: %s' % (outputBasePath + '.json'))
        
        # - import the folder
        self.assertEqual(0, self.r._import.main(argv=[
            '--quiet', '--debug', '--host-name', targetServer.host, '--driver-port', str(targetServer.driver_port),
            '-d', outputFolder
        ]))
        
        # - validate that the imported tables match, as the other round-trip tests do
        self.compare_servers(targetConn, tables=self.tableNames)
    
    def test_index_rebuild(self):
        '''Run the index-rebuild tool plainly, with --force, and with --force on a single table.'''
        sourceServer  = self.cluster[0]
        
        # - rebuild indexes
        self.assertEqual(0, self.r._index_rebuild.main(argv=[
            '--quiet', '--debug', '--host-name', sourceServer.host, '--driver-port', str(sourceServer.driver_port)
        ]))
        
        # - rebuild indexes with force
        self.assertEqual(0, self.r._index_rebuild.main(argv=[
            '--quiet', '--debug', '--host-name', sourceServer.host, '--driver-port', str(sourceServer.driver_port),
            '--force'
        ]))
        
        # - rebuild indexes on the specific table with force
        self.assertEqual(0, self.r._index_rebuild.main(argv=[
            '--quiet', '--debug', '--host-name', sourceServer.host, '--driver-port', str(sourceServer.driver_port),
            '--force', '--rebuild', '%s.%s' % (self.dbName, self.tableName)
        ]))

class ExportCSV(BackupTesting):
    '''Round-trip test for CSV export and import, using flat single-field documents.'''
    
    def populateTable(self, conn=None, table=None, records=None, fieldName=None):
        '''Insert `records` documents whose only field holds the numbers 1..records as strings.
        
        Args:
            conn: connection to run the query on, defaults to `self.conn`
            table: table term to fill, defaults to `self.table`
            records: number of documents to insert, defaults to `self.recordsToGenerate`
            fieldName: name of the field to write; when None the test case's `fieldName`
                attribute is used if it has one, otherwise 'id'
        
        Returns:
            The number of records inserted.
        
        Raises:
            AssertionError: when the insert reports errors or an unexpected count
        '''
        if conn is None:
            conn = self.conn
        if table is None:
            table = self.table
        if records is None:
            records = self.recordsToGenerate
        if fieldName is None:
            # None is not usable as an object key; 'id' is the first column of the
            # CSV export field list used in test_export_import_csv
            fieldName = getattr(self, 'fieldName', None) or 'id'
        
        res = table.insert(conn._r.range(1, records + 1).map({fieldName:conn._r.row.coerce_to('STRING')})).run(conn)
        self.assertEqual(res['errors'], 0, 'There were errors inserting data into table: %s' % res)
        self.assertEqual(res['inserted'], records)
        return records
    
    def test_export_import_csv(self):
        '''Export each table as CSV, import its file into a fresh server, and compare that table.'''
        sourceServer  = self.cluster[0]
        targetCluster = driver.Cluster(initial_servers=1)
        targetServer  = targetCluster[0]
        targetConn    = self.r.connect(host=targetServer.host, port=str(targetServer.driver_port))
        
        outputFolder  = tempfile.mkdtemp()
        utils.cleanupPathAtExit(outputFolder)
        
        for table in self.tableNames:
            export_folder = os.path.join(outputFolder, 'export_%s' % table)
            
            # - export the table
            self.assertEqual(0, self.r._export.main(argv=[
                '--quiet', '--debug', '--host-name', sourceServer.host, '--driver-port', str(sourceServer.driver_port),
                '-d', export_folder, '-e', '%s.%s' % (self.dbName, table), '--format', 'csv',
                '--fields', 'id,bool,integer,float,string,unicode,dict,array,complicated,binary,time,circle'
            ]))
            
            # - verify existence of files
            outputBasePath = os.path.join(export_folder, self.dbName, table)
            self.assertTrue(os.path.exists(outputBasePath + '.info'), 'There was no file at: %s' % (outputBasePath + '.info'))
            self.assertTrue(os.path.exists(outputBasePath + '.csv'), 'There was no file at: %s' % (outputBasePath + '.csv'))
            
            # - import the table
            self.assertEqual(0, self.r._import.main(argv=[
                '--quiet', '--debug', '--host-name', targetServer.host, '--driver-port', str(targetServer.driver_port),
                '-f', outputBasePath + '.csv', '--table', '%s.%s' % (self.dbName, table)
            ]))
            
            # - validate that the data matches
            self.compare_servers(targetConn, tables=[table])

# when executed directly (not imported), hand control to rdb_unittest.main()
if __name__ == '__main__':
    rdb_unittest.main()
